{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attention Scoring Functions"
   ]
  },
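  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In 02 we used a Gaussian kernel to model the relationship between queries and keys.\n",
    "\n",
    "The exponent of the Gaussian kernel can be viewed as an *attention scoring function*, or *scoring function* for short.\n",
    "\n",
    "Feeding the output of this function into a softmax yields a probability distribution over the values paired with the keys, i.e. the attention weights.\n",
    "\n",
    "The output of attention pooling is then the weighted sum of the values under these attention weights.\n",
    "\n",
    "At a high level, this procedure implements the attention mechanism framework shown below.\n",
    "\n",
    "![img](../img/attention-output.svg)\n",
    "\n",
    "Mathematically, suppose we have a query $\\mathbf{q} \\in \\mathbb{R}^q$ and $m$ key-value pairs $(\\mathbf{k}_1, \\mathbf{v}_1), \\ldots, (\\mathbf{k}_m, \\mathbf{v}_m)$, where $\\mathbf{k}_i \\in \\mathbb{R}^k$ and $\\mathbf{v}_i \\in \\mathbb{R}^v$. The attention pooling function $f$ is expressed as a weighted sum of the values:\n",
    "\n",
    "$$f(\\mathbf{q}, (\\mathbf{k}_1, \\mathbf{v}_1), \\ldots, (\\mathbf{k}_m, \\mathbf{v}_m)) = \\sum_{i=1}^m \\alpha(\\mathbf{q}, \\mathbf{k}_i) \\mathbf{v}_i \\in \\mathbb{R}^v,$$\n",
    "\n",
    "where the attention weight (a scalar) for query $\\mathbf{q}$ and key $\\mathbf{k}_i$ is obtained by first mapping the two vectors to a scalar with the attention scoring function $a$ and then applying a softmax:\n",
    "\n",
    "$$\\alpha(\\mathbf{q}, \\mathbf{k}_i) = \\mathrm{softmax}(a(\\mathbf{q}, \\mathbf{k}_i)) = \\frac{\\exp(a(\\mathbf{q}, \\mathbf{k}_i))}{\\sum_{j=1}^m \\exp(a(\\mathbf{q}, \\mathbf{k}_j))} \\in \\mathbb{R}.$$\n",
    "\n",
    "As we can see, choosing a different attention scoring function $a$ leads to a different attention pooling operation."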
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import d2lzh_pytorch as d2l"
   ]
  },
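  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Masked Softmax Operation\n",
    "\n",
    "As mentioned above, the softmax operation outputs a probability distribution that serves as the attention weights. In some cases, not all values should take part in attention pooling.\n",
    "\n",
    "For example, to process minibatches efficiently, some text sequences are padded with special tokens that carry no meaning. To restrict attention pooling to the meaningful tokens, we can specify a valid sequence length (the number of tokens) so that positions beyond it are filtered out when the softmax is computed. The `masked_softmax` function below implements this *masked softmax operation*: the score at every position beyond the valid length is replaced with a very large negative value (`-1e6`), so its softmax output is effectively 0."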
   ]
  },
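  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example (the numbers are illustrative assumptions, not taken from any cell in this notebook), consider one row of scores $(1, 2, 3, 4)$ with valid length 2. Masking replaces the last two scores with $-10^6$, whose exponentials underflow to 0, so the softmax reduces to a softmax over the first two entries:\n",
    "\n",
    "$$\\mathrm{softmax}(1, 2, -10^6, -10^6) \\approx \\left(\\frac{e^1}{e^1 + e^2}, \\frac{e^2}{e^1 + e^2}, 0, 0\\right) \\approx (0.2689, 0.7311, 0, 0).$$\n",
    "\n",
    "The surviving weights still sum to 1, which is why the mask is applied to the scores before the softmax rather than to the weights after it."
   ]
  },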
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple debug logger: messages are printed only when flag is truthy\n",
    "class logger:\n",
    "    def __init__(self, flag):\n",
    "        self.flag=flag\n",
    "\n",
    "    def print(self,*args):\n",
    "        if self.flag:\n",
    "            print(\"INFO:\",*args)\n",
    "\n",
    "    def print_arg_shape(self,queries,keys,values,vaild_lens):\n",
    "        if not self.flag:\n",
    "            return\n",
    "        # vaild_lens may be None (no masking), so guard the .shape access\n",
    "        vaild_shape=None if vaild_lens is None else vaild_lens.shape\n",
    "        print(\"====================== Arg.Shape ========================\")\n",
    "        print(\"queris.shape\",queries.shape)\n",
    "        print(\"keys.shape\",keys.shape)\n",
    "        print(\"values.shape\",values.shape)\n",
    "        print(\"vaild_lens\",vaild_shape)\n",
    "        print(\"====================== Arg.Shape ========================\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "def masked_softmax(X, vaild_lens,flag=False):\n",
    "    \"\"\"Perform softmax over the last axis, masking out positions beyond the valid length.\"\"\"\n",
    "    log=logger(flag)\n",
    "    \n",
    "    log.print(\"X_shape\" , X.shape)\n",
    "    log.print(\"X_value\" , X)\n",
    "    \n",
    "    if vaild_lens is None:\n",
    "        return nn.functional.softmax(X,dim=-1)\n",
    "    else:\n",
    "        shape=X.shape\n",
    "        if vaild_lens.dim()==1:\n",
    "            # One length per batch item: repeat it for every query row,\n",
    "            # e.g. torch.tensor([2, 3]) --> torch.tensor([2, 2, 3, 3])\n",
    "            vaild_lens=torch.repeat_interleave(vaild_lens,shape[1])\n",
    "        else:\n",
    "            # One length per (batch item, query) pair: flatten to 1-D\n",
    "            vaild_lens=vaild_lens.reshape(-1)\n",
    "        \n",
    "        log.print(\"vaild_lens_shape\" , vaild_lens.shape)\n",
    "        log.print( \"vaild_lens_values\" , vaild_lens)\n",
    "        \n",
    "        log.print(\"X.reshape(-1, shape[-1])\" , X.reshape(-1, shape[-1]))\n",
    "        log.print(\"X.reshape(-1, shape[-1]).shape\" , X.reshape(-1, shape[-1]).shape)\n",
    "        \n",
    "        log.print(\"Before sequence_mask X.value\",X)\n",
    "\n",
    "        # Replace masked positions on the last axis with a very large negative\n",
    "        # value so that their softmax output is effectively 0\n",
    "        X=d2l.sequence_mask(X.reshape(-1, shape[-1]), vaild_lens,value=-1e6)\n",
    "        log.print(\"after sequence_mask X.value\",X)\n",
    "        \n",
    "        # For reference, sequence_mask is\n",
    "        # defined in file: ./chapter_recurrent-modern/seq2seq.md\n",
    "        # def sequence_mask(X, valid_len, value=0):\n",
    "        #     \"\"\"Mask irrelevant entries in sequences.\"\"\"\n",
    "        #     maxlen = X.size(1)\n",
    "        #     mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]\n",
    "        #     X[~mask] = value\n",
    "        #     return X\n",
    "\n",
    "        return nn.functional.softmax(X.reshape(shape), dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![img](../img/attention_mask.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.4194, 0.5806, 0.0000, 0.0000],\n",
       "         [0.3499, 0.6501, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.3180, 0.2603, 0.4217, 0.0000],\n",
       "         [0.2563, 0.2481, 0.4956, 0.0000]]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# mask_softmax(X, vaild_lens)\n",
    "masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]), flag=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.2057, 0.4656, 0.3287, 0.0000]],\n",
       "\n",
       "        [[0.4520, 0.5480, 0.0000, 0.0000],\n",
       "         [0.3158, 0.1928, 0.3064, 0.1851]]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))"
   ]
  },
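  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Additive Attention\n",
    "\n",
    "In general, when queries and keys are vectors of different lengths, we can use additive attention as the scoring function. Given a query $\\mathbf{q} \\in \\mathbb{R}^q$ and a key $\\mathbf{k} \\in \\mathbb{R}^k$, the *additive attention* scoring function is\n",
    "\n",
    "$$a(\\mathbf q, \\mathbf k) = \\mathbf w_v^\\top \\tanh(\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k) \\in \\mathbb{R},$$\n",
    "\n",
    "where the learnable parameters are $\\mathbf W_q \\in \\mathbb{R}^{h \\times q}$, $\\mathbf W_k \\in \\mathbb{R}^{h \\times k}$ and $\\mathbf w_v \\in \\mathbb{R}^h$. As the formula shows, this is equivalent to concatenating the query and the key and feeding them into a multilayer perceptron (MLP) with a single hidden layer, whose number of hidden units $h$ is a hyperparameter. Using $\\tanh$ as the activation function and disabling the bias terms, we implement additive attention below."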
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdditiveAttention(nn.Module):\n",
    "    \"\"\"Additive attention\"\"\"\n",
    "    def __init__(self,key_size,query_size,num_hiddens,dropout,debug=False,**kwargs):\n",
    "        super(AdditiveAttention,self).__init__(**kwargs)\n",
    "        \n",
    "        # Example: AdditiveAttention(key_size=2,query_size=20,num_hiddens=8,dropout=0.1,debug=True)\n",
    "        \n",
    "        self.w_k=nn.Linear(key_size,num_hiddens,bias=False)\n",
    "        self.w_q=nn.Linear(query_size,num_hiddens,bias=False)\n",
    "        self.w_v=nn.Linear(num_hiddens,1,bias=False)\n",
    "        self.dropout=nn.Dropout(dropout)\n",
    "        self.log=logger(debug)\n",
    "\n",
    "    def forward(self,queries,keys,values,vaild_lens):\n",
    "        \n",
    "        self.log.print(\"====================== Arg.Shape ========================\")\n",
    "        self.log.print(\"queris.shape\",queries.shape)\n",
    "        self.log.print(\"keys.shape\",keys.shape)\n",
    "        self.log.print(\"values.shape\",values.shape)\n",
    "        # vaild_lens may be None (no masking), so guard the .shape access\n",
    "        self.log.print(\"vaild_lens\",None if vaild_lens is None else vaild_lens.shape)\n",
    "        self.log.print(\"====================== Arg.Shape ========================\")\n",
    "        \n",
    "        # Example output with debug=True:\n",
    "        # INFO: ====================== Arg.Shape ========================\n",
    "        # INFO: queris.shape torch.Size([2, 1, 20])\n",
    "        # INFO: keys.shape torch.Size([2, 10, 2])\n",
    "        # INFO: values.shape torch.Size([2, 10, 4])\n",
    "        # INFO: vaild_lens torch.Size([2])\n",
    "        # INFO: ====================== Arg.Shape ========================\n",
    "        \n",
    "        # Project queries and keys into the same hidden dimension (num_hiddens)\n",
    "        queries, keys=self.w_q(queries), self.w_k(keys)\n",
    "        \n",
    "        self.log.print()\n",
    "        self.log.print(\"====================== self.w_q & self.w_k ========================\")\n",
    "        self.log.print(\"queries\",queries.shape)\n",
    "        self.log.print(\"keys.shape\",keys.shape)\n",
    "        self.log.print(\"====================== self.w_q & self.w_k ========================\")\n",
    "        self.log.print()\n",
    "\n",
    "\n",
    "        # Additive attention. After unsqueezing:\n",
    "        #   queries: (batch_size, num_queries, 1, num_hiddens)\n",
    "        #   keys:    (batch_size, 1, num_kv_pairs, num_hiddens)\n",
    "        # and the sum broadcasts over both new axes.\n",
    "        features=queries.unsqueeze(2) + keys.unsqueeze(1)\n",
    "        \n",
    "        \n",
    "        self.log.print()\n",
    "        self.log.print(\"====================== unsqueeze ========================\")\n",
    "        self.log.print(\"queries.unsqueeze(2)\",queries.unsqueeze(2).shape)\n",
    "        self.log.print(\"keys.unsqueeze(1)\",keys.unsqueeze(1).shape)\n",
    "        self.log.print(\"====================== unsqueeze ========================\")\n",
    "        self.log.print()\n",
    "        \n",
    "        \n",
    "        features=torch.tanh(features)\n",
    "        # w_v has a single output, so drop that last dimension:\n",
    "        # scores: (batch_size, num_queries, num_kv_pairs)\n",
    "        scores=self.w_v(features).squeeze(-1)\n",
    "        self.log.print(\"scores\",scores.shape)\n",
    "\n",
    "        self.attention_weights=masked_softmax(scores,vaild_lens)\n",
    "\n",
    "        self.log.print(\"====================== finally ========================\")\n",
    "        self.log.print(\"self.attention_weights\",self.attention_weights.shape)\n",
    "        self.log.print(\"self.dropout(self.atten..)\",self.dropout(self.attention_weights))\n",
    "        self.log.print(\"====================== finally ========================\")\n",
    "        \n",
    "        # values: (batch_size, num_kv_pairs, value_dim)\n",
    "        return torch.bmm(self.dropout(self.attention_weights),values)\n",
    "\n",
    "    def get_linear_params(self):\n",
    "        self.log.print(self.w_q.state_dict())\n"
   ]
  },
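  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The broadcast sum in `forward` is easiest to follow by tracing shapes. The sketch below uses $b$ for the batch size, $n$ for the number of queries, $m$ for the number of key-value pairs, $h$ for `num_hiddens` and $v$ for the value dimension; these symbols are notation introduced here for illustration only.\n",
    "\n",
    "$$\\begin{aligned}\n",
    "\\text{projected queries} &: (b, n, h) \\to (b, n, 1, h) && \\text{after unsqueeze(2)} \\\\\n",
    "\\text{projected keys} &: (b, m, h) \\to (b, 1, m, h) && \\text{after unsqueeze(1)} \\\\\n",
    "\\text{features} &: (b, n, m, h) && \\text{broadcast sum, then } \\tanh \\\\\n",
    "\\text{scores} &: (b, n, m, 1) \\to (b, n, m) && \\text{after squeeze(-1)} \\\\\n",
    "\\text{output} &: (b, n, m) \\times (b, m, v) \\to (b, n, v) && \\text{batched matrix product}\n",
    "\\end{aligned}$$\n",
    "\n",
    "For the shapes listed in the comments of `forward` ($b=2$, $n=1$, $m=10$, $v=4$, with `num_hiddens=8`), the features have shape $(2, 1, 10, 8)$, the scores $(2, 1, 10)$ and the output $(2, 1, 4)$."
   ]
  },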
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "queries,keys=torch.normal(0,1,(2,1,20)), torch.ones((2,10,2))\n",
    "values=torch.arange(40,dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)\n",
    "vaild_lens=torch.tensor([2,6])\n",
    "attention=AdditiveAttention(key_size=2,query_size=20,num_hiddens=8,dropout=0.1)\n",
    "attention.eval()\n",
    "attention(queries,keys,values,vaild_lens)\n",
    "\n",
    "attention.get_linear_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nn.Dropout demo: in training mode each element is zeroed independently\n",
    "# with probability p on every forward call; survivors are scaled by 1/(1-p).\n",
    "m = nn.Dropout(p=0.2)\n",
    "x = torch.randn(50, 10)\n",
    "output = m(x)"
   ]
  },
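  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although additive attention contains learnable parameters, every key in this example is identical, so the attention weights are uniform over the valid positions, and the number of valid positions is set by the specified valid length."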
   ]
  },
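  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the weights are uniform, note that identical keys give the same score $a(\\mathbf q, \\mathbf k_i) = c$ at every position. Writing $\\ell$ for the valid length (a symbol introduced here for illustration) and ignoring the negligible contribution of the masked positions, the softmax becomes\n",
    "\n",
    "$$\\alpha(\\mathbf q, \\mathbf k_i) = \\frac{e^{c}}{\\ell \\, e^{c}} = \\frac{1}{\\ell} \\quad \\text{for } i \\le \\ell, \\qquad \\alpha(\\mathbf q, \\mathbf k_i) \\approx 0 \\quad \\text{for } i > \\ell.$$\n",
    "\n",
    "With the valid lengths 2 and 6 used in this example, the weights are $1/2$ and $1/6$ on the valid positions. In evaluation mode, where dropout is inactive, the output is therefore the mean of the first $\\ell$ rows of the values. Row $i$ of the values (counting from 0) is $(4i, 4i+1, 4i+2, 4i+3)$, so the two outputs are $(2, 3, 4, 5)$ and $(10, 11, 12, 13)$."
   ]
  },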
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"101.818906pt\" version=\"1.1\" viewBox=\"0 0 186.99575 101.818906\" width=\"186.99575pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <metadata>\r\n  <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\r\n   <cc:Work>\r\n    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\r\n    <dc:date>2021-11-27T09:58:34.943212</dc:date>\r\n    <dc:format>image/svg+xml</dc:format>\r\n    <dc:creator>\r\n     <cc:Agent>\r\n      <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\r\n     </cc:Agent>\r\n    </dc:creator>\r\n   </cc:Work>\r\n  </rdf:RDF>\r\n </metadata>\r\n <defs>\r\n  <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n  <g id=\"patch_1\">\r\n   <path d=\"M -0 101.818906 \r\nL 186.99575 101.818906 \r\nL 186.99575 0 \r\nL -0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n  </g>\r\n  <g id=\"axes_1\">\r\n   <g id=\"patch_2\">\r\n    <path d=\"M 34.240625 59.13 \r\nL 145.840625 59.13 \r\nL 145.840625 36.81 \r\nL 34.240625 36.81 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n   </g>\r\n   <g clip-path=\"url(#p508ef81f23)\">\r\n    <image height=\"23\" id=\"image46d5ff7c7b\" transform=\"scale(1 -1)translate(0 -23)\" width=\"112\" x=\"34.240625\" xlink:href=\"data:image/png;base64,\r\niVBORw0KGgoAAAANSUhEUgAAAHAAAAAXCAYAAADTEcupAAAAeElEQVR4nO3YsQmAQBAF0X+aWp8tGFqDXZjZnC2YCV5uJnIcA/MKWBaGTbbcx/ZESZJxXnuv8NnQewH9Y0A4A8IZEM6AcAaEMyCcAeEMCGdAOAPClSVTk1/ofp0txurFC4QzIJwB4QwIZ0A4A8IZEM6AcAaEMyBcBc2IB/NA+hblAAAAAElFTkSuQmCC\" y=\"-36.13\"/>\r\n   </g>\r\n   <g 
id=\"matplotlib.axis_1\">\r\n    <g id=\"xtick_1\">\r\n     <g id=\"line2d_1\">\r\n      <defs>\r\n       <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=\"m5728a421ed\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n      </defs>\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"39.820625\" xlink:href=\"#m5728a421ed\" y=\"59.13\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_1\">\r\n      <!-- 0 -->\r\n      <g transform=\"translate(36.639375 73.728437)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 31.78125 66.40625 \r\nQ 24.171875 66.40625 20.328125 58.90625 \r\nQ 16.5 51.421875 16.5 36.375 \r\nQ 16.5 21.390625 20.328125 13.890625 \r\nQ 24.171875 6.390625 31.78125 6.390625 \r\nQ 39.453125 6.390625 43.28125 13.890625 \r\nQ 47.125 21.390625 47.125 36.375 \r\nQ 47.125 51.421875 43.28125 58.90625 \r\nQ 39.453125 66.40625 31.78125 66.40625 \r\nz\r\nM 31.78125 74.21875 \r\nQ 44.046875 74.21875 50.515625 64.515625 \r\nQ 56.984375 54.828125 56.984375 36.375 \r\nQ 56.984375 17.96875 50.515625 8.265625 \r\nQ 44.046875 -1.421875 31.78125 -1.421875 \r\nQ 19.53125 -1.421875 13.0625 8.265625 \r\nQ 6.59375 17.96875 6.59375 36.375 \r\nQ 6.59375 54.828125 13.0625 64.515625 \r\nQ 19.53125 74.21875 31.78125 74.21875 \r\nz\r\n\" id=\"DejaVuSans-48\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"xtick_2\">\r\n     <g id=\"line2d_2\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"95.620625\" xlink:href=\"#m5728a421ed\" y=\"59.13\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_2\">\r\n      <!-- 5 -->\r\n      <g transform=\"translate(92.439375 73.728437)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 10.796875 72.90625 \r\nL 49.515625 72.90625 \r\nL 49.515625 64.59375 \r\nL 19.828125 64.59375 \r\nL 19.828125 46.734375 \r\nQ 21.96875 47.46875 24.109375 47.828125 \r\nQ 26.265625 48.1875 28.421875 48.1875 \r\nQ 40.625 48.1875 47.75 41.5 
\r\nQ 54.890625 34.8125 54.890625 23.390625 \r\nQ 54.890625 11.625 47.5625 5.09375 \r\nQ 40.234375 -1.421875 26.90625 -1.421875 \r\nQ 22.3125 -1.421875 17.546875 -0.640625 \r\nQ 12.796875 0.140625 7.71875 1.703125 \r\nL 7.71875 11.625 \r\nQ 12.109375 9.234375 16.796875 8.0625 \r\nQ 21.484375 6.890625 26.703125 6.890625 \r\nQ 35.15625 6.890625 40.078125 11.328125 \r\nQ 45.015625 15.765625 45.015625 23.390625 \r\nQ 45.015625 31 40.078125 35.4375 \r\nQ 35.15625 39.890625 26.703125 39.890625 \r\nQ 22.75 39.890625 18.8125 39.015625 \r\nQ 14.890625 38.140625 10.796875 36.28125 \r\nz\r\n\" id=\"DejaVuSans-53\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-53\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"text_3\">\r\n     <!-- keys -->\r\n     <g transform=\"translate(78.685938 87.406562)scale(0.1 -0.1)\">\r\n      <defs>\r\n       <path d=\"M 9.078125 75.984375 \r\nL 18.109375 75.984375 \r\nL 18.109375 31.109375 \r\nL 44.921875 54.6875 \r\nL 56.390625 54.6875 \r\nL 27.390625 29.109375 \r\nL 57.625 0 \r\nL 45.90625 0 \r\nL 18.109375 26.703125 \r\nL 18.109375 0 \r\nL 9.078125 0 \r\nz\r\n\" id=\"DejaVuSans-107\"/>\r\n       <path d=\"M 56.203125 29.59375 \r\nL 56.203125 25.203125 \r\nL 14.890625 25.203125 \r\nQ 15.484375 15.921875 20.484375 11.0625 \r\nQ 25.484375 6.203125 34.421875 6.203125 \r\nQ 39.59375 6.203125 44.453125 7.46875 \r\nQ 49.3125 8.734375 54.109375 11.28125 \r\nL 54.109375 2.78125 \r\nQ 49.265625 0.734375 44.1875 -0.34375 \r\nQ 39.109375 -1.421875 33.890625 -1.421875 \r\nQ 20.796875 -1.421875 13.15625 6.1875 \r\nQ 5.515625 13.8125 5.515625 26.8125 \r\nQ 5.515625 40.234375 12.765625 48.109375 \r\nQ 20.015625 56 32.328125 56 \r\nQ 43.359375 56 49.78125 48.890625 \r\nQ 56.203125 41.796875 56.203125 29.59375 \r\nz\r\nM 47.21875 32.234375 \r\nQ 47.125 39.59375 43.09375 43.984375 \r\nQ 39.0625 48.390625 32.421875 48.390625 \r\nQ 24.90625 48.390625 20.390625 44.140625 \r\nQ 15.875 39.890625 15.1875 32.171875 \r\nz\r\n\" 
id=\"DejaVuSans-101\"/>\r\n       <path d=\"M 32.171875 -5.078125 \r\nQ 28.375 -14.84375 24.75 -17.8125 \r\nQ 21.140625 -20.796875 15.09375 -20.796875 \r\nL 7.90625 -20.796875 \r\nL 7.90625 -13.28125 \r\nL 13.1875 -13.28125 \r\nQ 16.890625 -13.28125 18.9375 -11.515625 \r\nQ 21 -9.765625 23.484375 -3.21875 \r\nL 25.09375 0.875 \r\nL 2.984375 54.6875 \r\nL 12.5 54.6875 \r\nL 29.59375 11.921875 \r\nL 46.6875 54.6875 \r\nL 56.203125 54.6875 \r\nz\r\n\" id=\"DejaVuSans-121\"/>\r\n       <path d=\"M 44.28125 53.078125 \r\nL 44.28125 44.578125 \r\nQ 40.484375 46.53125 36.375 47.5 \r\nQ 32.28125 48.484375 27.875 48.484375 \r\nQ 21.1875 48.484375 17.84375 46.4375 \r\nQ 14.5 44.390625 14.5 40.28125 \r\nQ 14.5 37.15625 16.890625 35.375 \r\nQ 19.28125 33.59375 26.515625 31.984375 \r\nL 29.59375 31.296875 \r\nQ 39.15625 29.25 43.1875 25.515625 \r\nQ 47.21875 21.78125 47.21875 15.09375 \r\nQ 47.21875 7.46875 41.1875 3.015625 \r\nQ 35.15625 -1.421875 24.609375 -1.421875 \r\nQ 20.21875 -1.421875 15.453125 -0.5625 \r\nQ 10.6875 0.296875 5.421875 2 \r\nL 5.421875 11.28125 \r\nQ 10.40625 8.6875 15.234375 7.390625 \r\nQ 20.0625 6.109375 24.8125 6.109375 \r\nQ 31.15625 6.109375 34.5625 8.28125 \r\nQ 37.984375 10.453125 37.984375 14.40625 \r\nQ 37.984375 18.0625 35.515625 20.015625 \r\nQ 33.0625 21.96875 24.703125 23.78125 \r\nL 21.578125 24.515625 \r\nQ 13.234375 26.265625 9.515625 29.90625 \r\nQ 5.8125 33.546875 5.8125 39.890625 \r\nQ 5.8125 47.609375 11.28125 51.796875 \r\nQ 16.75 56 26.8125 56 \r\nQ 31.78125 56 36.171875 55.265625 \r\nQ 40.578125 54.546875 44.28125 53.078125 \r\nz\r\n\" id=\"DejaVuSans-115\"/>\r\n      </defs>\r\n      <use xlink:href=\"#DejaVuSans-107\"/>\r\n      <use x=\"54.285156\" xlink:href=\"#DejaVuSans-101\"/>\r\n      <use x=\"115.808594\" xlink:href=\"#DejaVuSans-121\"/>\r\n      <use x=\"174.988281\" xlink:href=\"#DejaVuSans-115\"/>\r\n     </g>\r\n    </g>\r\n   </g>\r\n   <g id=\"matplotlib.axis_2\">\r\n    <g id=\"ytick_1\">\r\n     <g 
id=\"line2d_3\">\r\n      <defs>\r\n       <path d=\"M 0 0 \r\nL -3.5 0 \r\n\" id=\"m6fe36a5a71\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n      </defs>\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"34.240625\" xlink:href=\"#m6fe36a5a71\" y=\"42.39\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_4\">\r\n      <!-- 0 -->\r\n      <g transform=\"translate(20.878125 46.189219)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_2\">\r\n     <g id=\"line2d_4\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"34.240625\" xlink:href=\"#m6fe36a5a71\" y=\"53.55\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_5\">\r\n      <!-- 1 -->\r\n      <g transform=\"translate(20.878125 57.349219)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 12.40625 8.296875 \r\nL 28.515625 8.296875 \r\nL 28.515625 63.921875 \r\nL 10.984375 60.40625 \r\nL 10.984375 69.390625 \r\nL 28.421875 72.90625 \r\nL 38.28125 72.90625 \r\nL 38.28125 8.296875 \r\nL 54.390625 8.296875 \r\nL 54.390625 0 \r\nL 12.40625 0 \r\nz\r\n\" id=\"DejaVuSans-49\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-49\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"text_6\">\r\n     <!-- queries -->\r\n     <g transform=\"translate(14.798437 66.515312)rotate(-90)scale(0.1 -0.1)\">\r\n      <defs>\r\n       <path d=\"M 14.796875 27.296875 \r\nQ 14.796875 17.390625 18.875 11.75 \r\nQ 22.953125 6.109375 30.078125 6.109375 \r\nQ 37.203125 6.109375 41.296875 11.75 \r\nQ 45.40625 17.390625 45.40625 27.296875 \r\nQ 45.40625 37.203125 41.296875 42.84375 \r\nQ 37.203125 48.484375 30.078125 48.484375 \r\nQ 22.953125 48.484375 18.875 42.84375 \r\nQ 14.796875 37.203125 14.796875 27.296875 \r\nz\r\nM 45.40625 8.203125 \r\nQ 42.578125 3.328125 38.25 0.953125 \r\nQ 33.9375 -1.421875 27.875 -1.421875 \r\nQ 17.96875 -1.421875 11.734375 6.484375 \r\nQ 5.515625 14.40625 
5.515625 27.296875 \r\nQ 5.515625 40.1875 11.734375 48.09375 \r\nQ 17.96875 56 27.875 56 \r\nQ 33.9375 56 38.25 53.625 \r\nQ 42.578125 51.265625 45.40625 46.390625 \r\nL 45.40625 54.6875 \r\nL 54.390625 54.6875 \r\nL 54.390625 -20.796875 \r\nL 45.40625 -20.796875 \r\nz\r\n\" id=\"DejaVuSans-113\"/>\r\n       <path d=\"M 8.5 21.578125 \r\nL 8.5 54.6875 \r\nL 17.484375 54.6875 \r\nL 17.484375 21.921875 \r\nQ 17.484375 14.15625 20.5 10.265625 \r\nQ 23.53125 6.390625 29.59375 6.390625 \r\nQ 36.859375 6.390625 41.078125 11.03125 \r\nQ 45.3125 15.671875 45.3125 23.6875 \r\nL 45.3125 54.6875 \r\nL 54.296875 54.6875 \r\nL 54.296875 0 \r\nL 45.3125 0 \r\nL 45.3125 8.40625 \r\nQ 42.046875 3.421875 37.71875 1 \r\nQ 33.40625 -1.421875 27.6875 -1.421875 \r\nQ 18.265625 -1.421875 13.375 4.4375 \r\nQ 8.5 10.296875 8.5 21.578125 \r\nz\r\nM 31.109375 56 \r\nz\r\n\" id=\"DejaVuSans-117\"/>\r\n       <path d=\"M 41.109375 46.296875 \r\nQ 39.59375 47.171875 37.8125 47.578125 \r\nQ 36.03125 48 33.890625 48 \r\nQ 26.265625 48 22.1875 43.046875 \r\nQ 18.109375 38.09375 18.109375 28.8125 \r\nL 18.109375 0 \r\nL 9.078125 0 \r\nL 9.078125 54.6875 \r\nL 18.109375 54.6875 \r\nL 18.109375 46.1875 \r\nQ 20.953125 51.171875 25.484375 53.578125 \r\nQ 30.03125 56 36.53125 56 \r\nQ 37.453125 56 38.578125 55.875 \r\nQ 39.703125 55.765625 41.0625 55.515625 \r\nz\r\n\" id=\"DejaVuSans-114\"/>\r\n       <path d=\"M 9.421875 54.6875 \r\nL 18.40625 54.6875 \r\nL 18.40625 0 \r\nL 9.421875 0 \r\nz\r\nM 9.421875 75.984375 \r\nL 18.40625 75.984375 \r\nL 18.40625 64.59375 \r\nL 9.421875 64.59375 \r\nz\r\n\" id=\"DejaVuSans-105\"/>\r\n      </defs>\r\n      <use xlink:href=\"#DejaVuSans-113\"/>\r\n      <use x=\"63.476562\" xlink:href=\"#DejaVuSans-117\"/>\r\n      <use x=\"126.855469\" xlink:href=\"#DejaVuSans-101\"/>\r\n      <use x=\"188.378906\" xlink:href=\"#DejaVuSans-114\"/>\r\n      <use x=\"229.492188\" xlink:href=\"#DejaVuSans-105\"/>\r\n      <use x=\"257.275391\" 
xlink:href=\"#DejaVuSans-101\"/>\r\n      <use x=\"318.798828\" xlink:href=\"#DejaVuSans-115\"/>\r\n     </g>\r\n    </g>\r\n   </g>\r\n   <g id=\"patch_3\">\r\n    <path d=\"M 34.240625 59.13 \r\nL 34.240625 36.81 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_4\">\r\n    <path d=\"M 145.840625 59.13 \r\nL 145.840625 36.81 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_5\">\r\n    <path d=\"M 34.240625 59.13 \r\nL 145.840625 59.13 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_6\">\r\n    <path d=\"M 34.240625 36.81 \r\nL 145.840625 36.81 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n  </g>\r\n  <g id=\"axes_2\">\r\n   <g id=\"patch_7\">\r\n    <path clip-path=\"url(#p46468f8c41)\" d=\"M 152.815625 88.74 \r\nL 152.815625 88.421484 \r\nL 152.815625 7.518516 \r\nL 152.815625 7.2 \r\nL 156.892625 7.2 \r\nL 156.892625 7.518516 \r\nL 156.892625 88.421484 \r\nL 156.892625 88.74 \r\nz\r\n\" style=\"fill:#ffffff;stroke:#ffffff;stroke-linejoin:miter;stroke-width:0.01;\"/>\r\n   </g>\r\n   <image height=\"82\" id=\"image1a99743e6f\" transform=\"scale(1 -1)translate(0 -82)\" width=\"4\" x=\"153\" xlink:href=\"data:image/png;base64,\r\niVBORw0KGgoAAAANSUhEUgAAAAQAAABSCAYAAABzJnWUAAAAnklEQVR4nJ2SOw5CMQwEjfTuf1QqChT/aMnskwxJl9F414ry6Nez7etclmE7qCLIAfRonABZLBy1o2FJEBzxxdD/M7R2NFh7EzpmHBhSm3z1H/YAaDEk1OfVBfCPCXCAEmPJppJRg1Fs0QyCYGhmw4jRIMjkSMgIjBBDapvGfrcrTYwDYINxsweAM8NpvAtgCeDIYqiMSIsYUrs/sdkHBN7KuQYqCJAAAAAASUVORK5CYII=\" y=\"-6\"/>\r\n   <g id=\"matplotlib.axis_3\"/>\r\n   <g id=\"matplotlib.axis_4\">\r\n    <g id=\"ytick_3\">\r\n     <g id=\"line2d_5\">\r\n      <defs>\r\n       <path d=\"M 0 0 \r\nL 3.5 0 \r\n\" id=\"mfdb9a37789\" 
style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n      </defs>\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#mfdb9a37789\" y=\"88.74\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_7\">\r\n      <!-- 0.0 -->\r\n      <g transform=\"translate(163.892625 92.539219)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 10.6875 12.40625 \r\nL 21 12.40625 \r\nL 21 0 \r\nL 10.6875 0 \r\nz\r\n\" id=\"DejaVuSans-46\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\r\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_4\">\r\n     <g id=\"line2d_6\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#mfdb9a37789\" y=\"56.124\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_8\">\r\n      <!-- 0.2 -->\r\n      <g transform=\"translate(163.892625 59.923219)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 19.1875 8.296875 \r\nL 53.609375 8.296875 \r\nL 53.609375 0 \r\nL 7.328125 0 \r\nL 7.328125 8.296875 \r\nQ 12.9375 14.109375 22.625 23.890625 \r\nQ 32.328125 33.6875 34.8125 36.53125 \r\nQ 39.546875 41.84375 41.421875 45.53125 \r\nQ 43.3125 49.21875 43.3125 52.78125 \r\nQ 43.3125 58.59375 39.234375 62.25 \r\nQ 35.15625 65.921875 28.609375 65.921875 \r\nQ 23.96875 65.921875 18.8125 64.3125 \r\nQ 13.671875 62.703125 7.8125 59.421875 \r\nL 7.8125 69.390625 \r\nQ 13.765625 71.78125 18.9375 73 \r\nQ 24.125 74.21875 28.421875 74.21875 \r\nQ 39.75 74.21875 46.484375 68.546875 \r\nQ 53.21875 62.890625 53.21875 53.421875 \r\nQ 53.21875 48.921875 51.53125 44.890625 \r\nQ 49.859375 40.875 45.40625 35.40625 \r\nQ 44.1875 33.984375 37.640625 27.21875 \r\nQ 31.109375 20.453125 19.1875 8.296875 \r\nz\r\n\" id=\"DejaVuSans-50\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n       <use 
x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\r\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_5\">\r\n     <g id=\"line2d_7\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#mfdb9a37789\" y=\"23.508\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_9\">\r\n      <!-- 0.4 -->\r\n      <g transform=\"translate(163.892625 27.307219)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 37.796875 64.3125 \r\nL 12.890625 25.390625 \r\nL 37.796875 25.390625 \r\nz\r\nM 35.203125 72.90625 \r\nL 47.609375 72.90625 \r\nL 47.609375 25.390625 \r\nL 58.015625 25.390625 \r\nL 58.015625 17.1875 \r\nL 47.609375 17.1875 \r\nL 47.609375 0 \r\nL 37.796875 0 \r\nL 37.796875 17.1875 \r\nL 4.890625 17.1875 \r\nL 4.890625 26.703125 \r\nz\r\n\" id=\"DejaVuSans-52\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\r\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-52\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n   </g>\r\n   <g id=\"patch_8\">\r\n    <path d=\"M 152.815625 88.74 \r\nL 152.815625 88.421484 \r\nL 152.815625 7.518516 \r\nL 152.815625 7.2 \r\nL 156.892625 7.2 \r\nL 156.892625 7.518516 \r\nL 156.892625 88.421484 \r\nL 156.892625 88.74 \r\nz\r\n\" style=\"fill:none;stroke:#000000;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n  </g>\r\n </g>\r\n <defs>\r\n  <clipPath id=\"p508ef81f23\">\r\n   <rect height=\"22.32\" width=\"111.6\" x=\"34.240625\" y=\"36.81\"/>\r\n  </clipPath>\r\n  <clipPath id=\"p46468f8c41\">\r\n   <rect height=\"81.54\" width=\"4.077\" x=\"152.815625\" y=\"7.2\"/>\r\n  </clipPath>\r\n </defs>\r\n</svg>\r\n",
      "text/plain": [
       "<Figure size 180x180 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.show_heatmaps(attention.attention_weights.reshape((1,1,2,10)) , xlabel='keys',ylabel='queries')"
   ]
  },
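  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Notes: ways to inspect the data\n",
    "\n",
    "1. Print tensor shapes to follow how they change.\n",
    "2. While debugging, evaluate variables as expressions in the Expression field.\n",
    "3. In the PyCharm debugger, use Copy Values to inspect the data.\n",
    "   If the output is truncated, there are two fixes:\n",
    "   * Call `torch.set_printoptions(threshold=np.inf)` (this needs `import numpy as np`).\n",
    "   * In the Debug Console, convert the Tensor to NumPy and inspect the matrix in PyCharm's SciView."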
   ]
  },
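  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scaled Dot-Product Attention\n",
    "\n",
    "A dot product gives a computationally more efficient scoring function, but it requires the query and the key to have the same length $d$.\n",
    "\n",
    "Assume that all elements of the query and the key are independent random variables with zero mean and unit variance. Then the dot product of the two vectors has mean 0 and variance $d$.\n",
    "\n",
    "To keep the variance of the dot product at 1 regardless of the vector length, *scaled dot-product attention* divides the dot product by $\\sqrt{d}$:\n",
    "\n",
    "$$a(\\mathbf q, \\mathbf k) = \\mathbf{q}^\\top \\mathbf{k} / \\sqrt{d}.$$\n",
    "\n",
    "In practice, we usually work with minibatches for efficiency, e.g. computing attention for $n$ queries and $m$ key-value pairs, where queries and keys have length $d$ and values have length $v$. The scaled dot-product attention of the queries $\\mathbf Q \\in \\mathbb{R}^{n \\times d}$, keys $\\mathbf K \\in \\mathbb{R}^{m \\times d}$ and values $\\mathbf V \\in \\mathbb{R}^{m \\times v}$ is\n",
    "\n",
    "$$\\mathrm{softmax}\\left(\\frac{\\mathbf Q \\mathbf K^\\top }{\\sqrt{d}}\\right) \\mathbf V \\in \\mathbb{R}^{n\\times v}.$$\n",
    "\n",
    "In the implementation of scaled dot-product attention below, we use dropout for model regularization."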
   ]
  },
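  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mean and variance claimed above can be checked in a few lines. This is a sketch under the stated assumption that all elements $q_i$ and $k_i$ are independent with zero mean and unit variance:\n",
    "\n",
    "$$\\mathbb{E}[\\mathbf q^\\top \\mathbf k] = \\sum_{i=1}^d \\mathbb{E}[q_i]\\,\\mathbb{E}[k_i] = 0,$$\n",
    "\n",
    "$$\\mathrm{Var}(q_i k_i) = \\mathbb{E}[q_i^2]\\,\\mathbb{E}[k_i^2] - \\left(\\mathbb{E}[q_i]\\,\\mathbb{E}[k_i]\\right)^2 = 1 \\cdot 1 - 0 = 1,$$\n",
    "\n",
    "$$\\mathrm{Var}(\\mathbf q^\\top \\mathbf k) = \\sum_{i=1}^d \\mathrm{Var}(q_i k_i) = d, \\qquad \\mathrm{Var}\\left(\\frac{\\mathbf q^\\top \\mathbf k}{\\sqrt{d}}\\right) = \\frac{d}{d} = 1.$$\n",
    "\n",
    "The variances add because the $d$ products $q_i k_i$ are themselves independent. Without the scaling, the scores would spread out as $d$ grows and the softmax would concentrate almost all weight on a single key."
   ]
  },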
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DotProductAttention(nn.Module):\n",
    "    \"\"\"Scaled dot-product attention\"\"\"\n",
    "    def __init__(self,dropout,debug = False,**kwargs):\n",
    "        super(DotProductAttention,self).__init__(**kwargs)\n",
    "        self.dropout=nn.Dropout(dropout)\n",
    "        self.log=logger(debug)\n",
    "\n",
    "    def forward(self,queries,keys,values,vaild_lens=None):\n",
    "        self.log.print_arg_shape(queries,keys,values,vaild_lens)\n",
    "        \n",
    "        # Queries and keys share the same feature dimension d\n",
    "        d=queries.shape[-1]\n",
    "        self.log.print(\"d\",d)\n",
    "\n",
    "        self.log.print(\"keys\",keys)\n",
    "        self.log.print(\"keys.transpose.shape\",keys.transpose(1,2).shape)\n",
    "        \n",
    "        # Scaled dot-product scores:\n",
    "        # (batch_size, num_queries, d) x (batch_size, d, num_kv_pairs)\n",
    "        # -> (batch_size, num_queries, num_kv_pairs)\n",
    "        scores=torch.bmm(queries,keys.transpose(1,2)) / math.sqrt(d)\n",
    "        \n",
    "        self.log.print(\"scores.shape\",scores.shape)\n",
    "        \n",
    "        self.attention_weights=masked_softmax(scores,vaild_lens)\n",
    "        return torch.bmm(self.dropout(self.attention_weights),values)\n"
   ]
  },
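  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate the `DotProductAttention` class above, we reuse the keys, values and valid lengths from the earlier additive attention example. For the dot product, the feature dimension of the queries is set equal to that of the keys."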
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================== Arg.Shape ========================\n",
      "queris.shape torch.Size([2, 1, 2])\n",
      "keys.shape torch.Size([2, 10, 2])\n",
      "values.shape torch.Size([2, 10, 4])\n",
      "vaild_lens torch.Size([2])\n",
      "====================== Arg.Shape ========================\n",
      "INFO: d 2\n",
      "INFO: keys tensor([[[1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.]],\n",
      "\n",
      "        [[1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.]]])\n",
      "INFO: keys.transpose.shape torch.Size([2, 2, 10])\n",
      "INFO: scores.shape torch.Size([2, 1, 10])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],\n",
       "\n",
       "        [[10.0000, 11.0000, 12.0000, 13.0000]]])"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# queries,keys=torch.normal(0,1,(2,1,20)), torch.ones((2,10,2))\n",
    "# values=torch.arange(40,dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)\n",
    "# vaild_lens=torch.tensor([2,6])\n",
    "\n",
    "queries = torch.normal(0,1,(2,1,2))\n",
    "attention=DotProductAttention(dropout=0.5, debug=True)\n",
    "attention.eval()\n",
    "attention(queries,keys,values,vaild_lens)"
   ]
  },
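  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output above can be checked by hand. This sketch assumes that `values` and `vaild_lens` are defined as in the commented-out lines of the previous cell, so that row $i$ of `values` is $(4i, 4i+1, 4i+2, 4i+3)$ for $i = 0, \\ldots, 9$ and the valid lengths are 2 and 6. Since all keys are identical, every key receives the same score for a given query, so the masked softmax assigns uniform weights to the valid positions and the output is the mean of the first valid rows of `values`:\n",
    "\n",
    "${ \\frac{1}{2} \\sum_{i=0}^{1} (4i, 4i+1, 4i+2, 4i+3) = (2, 3, 4, 5) }$\n",
    "\n",
    "${ \\frac{1}{6} \\sum_{i=0}^{5} (4i, 4i+1, 4i+2, 4i+3) = (10, 11, 12, 13) }$\n",
    "\n",
    "This matches the printed tensor. Dropout has no effect here because the module is in evaluation mode (`attention.eval()`)."
   ]
  },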
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"101.818906pt\" version=\"1.1\" viewBox=\"0 0 186.99575 101.818906\" width=\"186.99575pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <metadata>\r\n  <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\r\n   <cc:Work>\r\n    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\r\n    <dc:date>2021-11-27T20:06:29.096822</dc:date>\r\n    <dc:format>image/svg+xml</dc:format>\r\n    <dc:creator>\r\n     <cc:Agent>\r\n      <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\r\n     </cc:Agent>\r\n    </dc:creator>\r\n   </cc:Work>\r\n  </rdf:RDF>\r\n </metadata>\r\n <defs>\r\n  <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n  <g id=\"patch_1\">\r\n   <path d=\"M -0 101.818906 \r\nL 186.99575 101.818906 \r\nL 186.99575 0 \r\nL -0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n  </g>\r\n  <g id=\"axes_1\">\r\n   <g id=\"patch_2\">\r\n    <path d=\"M 34.240625 59.13 \r\nL 145.840625 59.13 \r\nL 145.840625 36.81 \r\nL 34.240625 36.81 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n   </g>\r\n   <g clip-path=\"url(#p2a4aafd979)\">\r\n    <image height=\"23\" id=\"imagee9e8e9105d\" transform=\"scale(1 -1)translate(0 -23)\" width=\"112\" x=\"34.240625\" xlink:href=\"data:image/png;base64,\r\niVBORw0KGgoAAAANSUhEUgAAAHAAAAAXCAYAAADTEcupAAAAeElEQVR4nO3YsQmAQBAF0X+aWp8tGFqDXZjZnC2YCV5uJnIcA/MKWBaGTbbcx/ZESZJxXnuv8NnQewH9Y0A4A8IZEM6AcAaEMyCcAeEMCGdAOAPClSVTk1/ofp0txurFC4QzIJwB4QwIZ0A4A8IZEM6AcAaEMyBcBc2IB/NA+hblAAAAAElFTkSuQmCC\" y=\"-36.13\"/>\r\n   </g>\r\n   <g 
id=\"matplotlib.axis_1\">\r\n    <g id=\"xtick_1\">\r\n     <g id=\"line2d_1\">\r\n      <defs>\r\n       <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=\"m96229da96a\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n      </defs>\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"39.820625\" xlink:href=\"#m96229da96a\" y=\"59.13\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_1\">\r\n      <!-- 0 -->\r\n      <g transform=\"translate(36.639375 73.728437)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 31.78125 66.40625 \r\nQ 24.171875 66.40625 20.328125 58.90625 \r\nQ 16.5 51.421875 16.5 36.375 \r\nQ 16.5 21.390625 20.328125 13.890625 \r\nQ 24.171875 6.390625 31.78125 6.390625 \r\nQ 39.453125 6.390625 43.28125 13.890625 \r\nQ 47.125 21.390625 47.125 36.375 \r\nQ 47.125 51.421875 43.28125 58.90625 \r\nQ 39.453125 66.40625 31.78125 66.40625 \r\nz\r\nM 31.78125 74.21875 \r\nQ 44.046875 74.21875 50.515625 64.515625 \r\nQ 56.984375 54.828125 56.984375 36.375 \r\nQ 56.984375 17.96875 50.515625 8.265625 \r\nQ 44.046875 -1.421875 31.78125 -1.421875 \r\nQ 19.53125 -1.421875 13.0625 8.265625 \r\nQ 6.59375 17.96875 6.59375 36.375 \r\nQ 6.59375 54.828125 13.0625 64.515625 \r\nQ 19.53125 74.21875 31.78125 74.21875 \r\nz\r\n\" id=\"DejaVuSans-48\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"xtick_2\">\r\n     <g id=\"line2d_2\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"95.620625\" xlink:href=\"#m96229da96a\" y=\"59.13\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_2\">\r\n      <!-- 5 -->\r\n      <g transform=\"translate(92.439375 73.728437)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 10.796875 72.90625 \r\nL 49.515625 72.90625 \r\nL 49.515625 64.59375 \r\nL 19.828125 64.59375 \r\nL 19.828125 46.734375 \r\nQ 21.96875 47.46875 24.109375 47.828125 \r\nQ 26.265625 48.1875 28.421875 48.1875 \r\nQ 40.625 48.1875 47.75 41.5 
\r\nQ 54.890625 34.8125 54.890625 23.390625 \r\nQ 54.890625 11.625 47.5625 5.09375 \r\nQ 40.234375 -1.421875 26.90625 -1.421875 \r\nQ 22.3125 -1.421875 17.546875 -0.640625 \r\nQ 12.796875 0.140625 7.71875 1.703125 \r\nL 7.71875 11.625 \r\nQ 12.109375 9.234375 16.796875 8.0625 \r\nQ 21.484375 6.890625 26.703125 6.890625 \r\nQ 35.15625 6.890625 40.078125 11.328125 \r\nQ 45.015625 15.765625 45.015625 23.390625 \r\nQ 45.015625 31 40.078125 35.4375 \r\nQ 35.15625 39.890625 26.703125 39.890625 \r\nQ 22.75 39.890625 18.8125 39.015625 \r\nQ 14.890625 38.140625 10.796875 36.28125 \r\nz\r\n\" id=\"DejaVuSans-53\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-53\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"text_3\">\r\n     <!-- Keys -->\r\n     <g transform=\"translate(78.371094 87.406562)scale(0.1 -0.1)\">\r\n      <defs>\r\n       <path d=\"M 9.8125 72.90625 \r\nL 19.671875 72.90625 \r\nL 19.671875 42.09375 \r\nL 52.390625 72.90625 \r\nL 65.09375 72.90625 \r\nL 28.90625 38.921875 \r\nL 67.671875 0 \r\nL 54.6875 0 \r\nL 19.671875 35.109375 \r\nL 19.671875 0 \r\nL 9.8125 0 \r\nz\r\n\" id=\"DejaVuSans-75\"/>\r\n       <path d=\"M 56.203125 29.59375 \r\nL 56.203125 25.203125 \r\nL 14.890625 25.203125 \r\nQ 15.484375 15.921875 20.484375 11.0625 \r\nQ 25.484375 6.203125 34.421875 6.203125 \r\nQ 39.59375 6.203125 44.453125 7.46875 \r\nQ 49.3125 8.734375 54.109375 11.28125 \r\nL 54.109375 2.78125 \r\nQ 49.265625 0.734375 44.1875 -0.34375 \r\nQ 39.109375 -1.421875 33.890625 -1.421875 \r\nQ 20.796875 -1.421875 13.15625 6.1875 \r\nQ 5.515625 13.8125 5.515625 26.8125 \r\nQ 5.515625 40.234375 12.765625 48.109375 \r\nQ 20.015625 56 32.328125 56 \r\nQ 43.359375 56 49.78125 48.890625 \r\nQ 56.203125 41.796875 56.203125 29.59375 \r\nz\r\nM 47.21875 32.234375 \r\nQ 47.125 39.59375 43.09375 43.984375 \r\nQ 39.0625 48.390625 32.421875 48.390625 \r\nQ 24.90625 48.390625 20.390625 44.140625 \r\nQ 15.875 39.890625 15.1875 32.171875 \r\nz\r\n\" 
id=\"DejaVuSans-101\"/>\r\n       <path d=\"M 32.171875 -5.078125 \r\nQ 28.375 -14.84375 24.75 -17.8125 \r\nQ 21.140625 -20.796875 15.09375 -20.796875 \r\nL 7.90625 -20.796875 \r\nL 7.90625 -13.28125 \r\nL 13.1875 -13.28125 \r\nQ 16.890625 -13.28125 18.9375 -11.515625 \r\nQ 21 -9.765625 23.484375 -3.21875 \r\nL 25.09375 0.875 \r\nL 2.984375 54.6875 \r\nL 12.5 54.6875 \r\nL 29.59375 11.921875 \r\nL 46.6875 54.6875 \r\nL 56.203125 54.6875 \r\nz\r\n\" id=\"DejaVuSans-121\"/>\r\n       <path d=\"M 44.28125 53.078125 \r\nL 44.28125 44.578125 \r\nQ 40.484375 46.53125 36.375 47.5 \r\nQ 32.28125 48.484375 27.875 48.484375 \r\nQ 21.1875 48.484375 17.84375 46.4375 \r\nQ 14.5 44.390625 14.5 40.28125 \r\nQ 14.5 37.15625 16.890625 35.375 \r\nQ 19.28125 33.59375 26.515625 31.984375 \r\nL 29.59375 31.296875 \r\nQ 39.15625 29.25 43.1875 25.515625 \r\nQ 47.21875 21.78125 47.21875 15.09375 \r\nQ 47.21875 7.46875 41.1875 3.015625 \r\nQ 35.15625 -1.421875 24.609375 -1.421875 \r\nQ 20.21875 -1.421875 15.453125 -0.5625 \r\nQ 10.6875 0.296875 5.421875 2 \r\nL 5.421875 11.28125 \r\nQ 10.40625 8.6875 15.234375 7.390625 \r\nQ 20.0625 6.109375 24.8125 6.109375 \r\nQ 31.15625 6.109375 34.5625 8.28125 \r\nQ 37.984375 10.453125 37.984375 14.40625 \r\nQ 37.984375 18.0625 35.515625 20.015625 \r\nQ 33.0625 21.96875 24.703125 23.78125 \r\nL 21.578125 24.515625 \r\nQ 13.234375 26.265625 9.515625 29.90625 \r\nQ 5.8125 33.546875 5.8125 39.890625 \r\nQ 5.8125 47.609375 11.28125 51.796875 \r\nQ 16.75 56 26.8125 56 \r\nQ 31.78125 56 36.171875 55.265625 \r\nQ 40.578125 54.546875 44.28125 53.078125 \r\nz\r\n\" id=\"DejaVuSans-115\"/>\r\n      </defs>\r\n      <use xlink:href=\"#DejaVuSans-75\"/>\r\n      <use x=\"60.576172\" xlink:href=\"#DejaVuSans-101\"/>\r\n      <use x=\"122.099609\" xlink:href=\"#DejaVuSans-121\"/>\r\n      <use x=\"181.279297\" xlink:href=\"#DejaVuSans-115\"/>\r\n     </g>\r\n    </g>\r\n   </g>\r\n   <g id=\"matplotlib.axis_2\">\r\n    <g id=\"ytick_1\">\r\n     <g 
id=\"line2d_3\">\r\n      <defs>\r\n       <path d=\"M 0 0 \r\nL -3.5 0 \r\n\" id=\"m670ee96195\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n      </defs>\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"34.240625\" xlink:href=\"#m670ee96195\" y=\"42.39\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_4\">\r\n      <!-- 0 -->\r\n      <g transform=\"translate(20.878125 46.189219)scale(0.1 -0.1)\">\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_2\">\r\n     <g id=\"line2d_4\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"34.240625\" xlink:href=\"#m670ee96195\" y=\"53.55\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_5\">\r\n      <!-- 1 -->\r\n      <g transform=\"translate(20.878125 57.349219)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 12.40625 8.296875 \r\nL 28.515625 8.296875 \r\nL 28.515625 63.921875 \r\nL 10.984375 60.40625 \r\nL 10.984375 69.390625 \r\nL 28.421875 72.90625 \r\nL 38.28125 72.90625 \r\nL 38.28125 8.296875 \r\nL 54.390625 8.296875 \r\nL 54.390625 0 \r\nL 12.40625 0 \r\nz\r\n\" id=\"DejaVuSans-49\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-49\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"text_6\">\r\n     <!-- Queries -->\r\n     <g transform=\"translate(14.798437 67.277031)rotate(-90)scale(0.1 -0.1)\">\r\n      <defs>\r\n       <path d=\"M 39.40625 66.21875 \r\nQ 28.65625 66.21875 22.328125 58.203125 \r\nQ 16.015625 50.203125 16.015625 36.375 \r\nQ 16.015625 22.609375 22.328125 14.59375 \r\nQ 28.65625 6.59375 39.40625 6.59375 \r\nQ 50.140625 6.59375 56.421875 14.59375 \r\nQ 62.703125 22.609375 62.703125 36.375 \r\nQ 62.703125 50.203125 56.421875 58.203125 \r\nQ 50.140625 66.21875 39.40625 66.21875 \r\nz\r\nM 53.21875 1.3125 \r\nL 66.21875 -12.890625 \r\nL 54.296875 -12.890625 \r\nL 43.5 -1.21875 \r\nQ 41.890625 -1.3125 41.03125 -1.359375 \r\nQ 40.1875 -1.421875 39.40625 
-1.421875 \r\nQ 24.03125 -1.421875 14.8125 8.859375 \r\nQ 5.609375 19.140625 5.609375 36.375 \r\nQ 5.609375 53.65625 14.8125 63.9375 \r\nQ 24.03125 74.21875 39.40625 74.21875 \r\nQ 54.734375 74.21875 63.90625 63.9375 \r\nQ 73.09375 53.65625 73.09375 36.375 \r\nQ 73.09375 23.6875 67.984375 14.640625 \r\nQ 62.890625 5.609375 53.21875 1.3125 \r\nz\r\n\" id=\"DejaVuSans-81\"/>\r\n       <path d=\"M 8.5 21.578125 \r\nL 8.5 54.6875 \r\nL 17.484375 54.6875 \r\nL 17.484375 21.921875 \r\nQ 17.484375 14.15625 20.5 10.265625 \r\nQ 23.53125 6.390625 29.59375 6.390625 \r\nQ 36.859375 6.390625 41.078125 11.03125 \r\nQ 45.3125 15.671875 45.3125 23.6875 \r\nL 45.3125 54.6875 \r\nL 54.296875 54.6875 \r\nL 54.296875 0 \r\nL 45.3125 0 \r\nL 45.3125 8.40625 \r\nQ 42.046875 3.421875 37.71875 1 \r\nQ 33.40625 -1.421875 27.6875 -1.421875 \r\nQ 18.265625 -1.421875 13.375 4.4375 \r\nQ 8.5 10.296875 8.5 21.578125 \r\nz\r\nM 31.109375 56 \r\nz\r\n\" id=\"DejaVuSans-117\"/>\r\n       <path d=\"M 41.109375 46.296875 \r\nQ 39.59375 47.171875 37.8125 47.578125 \r\nQ 36.03125 48 33.890625 48 \r\nQ 26.265625 48 22.1875 43.046875 \r\nQ 18.109375 38.09375 18.109375 28.8125 \r\nL 18.109375 0 \r\nL 9.078125 0 \r\nL 9.078125 54.6875 \r\nL 18.109375 54.6875 \r\nL 18.109375 46.1875 \r\nQ 20.953125 51.171875 25.484375 53.578125 \r\nQ 30.03125 56 36.53125 56 \r\nQ 37.453125 56 38.578125 55.875 \r\nQ 39.703125 55.765625 41.0625 55.515625 \r\nz\r\n\" id=\"DejaVuSans-114\"/>\r\n       <path d=\"M 9.421875 54.6875 \r\nL 18.40625 54.6875 \r\nL 18.40625 0 \r\nL 9.421875 0 \r\nz\r\nM 9.421875 75.984375 \r\nL 18.40625 75.984375 \r\nL 18.40625 64.59375 \r\nL 9.421875 64.59375 \r\nz\r\n\" id=\"DejaVuSans-105\"/>\r\n      </defs>\r\n      <use xlink:href=\"#DejaVuSans-81\"/>\r\n      <use x=\"78.710938\" xlink:href=\"#DejaVuSans-117\"/>\r\n      <use x=\"142.089844\" xlink:href=\"#DejaVuSans-101\"/>\r\n      <use x=\"203.613281\" xlink:href=\"#DejaVuSans-114\"/>\r\n      <use x=\"244.726562\" 
xlink:href=\"#DejaVuSans-105\"/>\r\n      <use x=\"272.509766\" xlink:href=\"#DejaVuSans-101\"/>\r\n      <use x=\"334.033203\" xlink:href=\"#DejaVuSans-115\"/>\r\n     </g>\r\n    </g>\r\n   </g>\r\n   <g id=\"patch_3\">\r\n    <path d=\"M 34.240625 59.13 \r\nL 34.240625 36.81 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_4\">\r\n    <path d=\"M 145.840625 59.13 \r\nL 145.840625 36.81 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_5\">\r\n    <path d=\"M 34.240625 59.13 \r\nL 145.840625 59.13 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n   <g id=\"patch_6\">\r\n    <path d=\"M 34.240625 36.81 \r\nL 145.840625 36.81 \r\n\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n  </g>\r\n  <g id=\"axes_2\">\r\n   <g id=\"patch_7\">\r\n    <path clip-path=\"url(#p7469d24906)\" d=\"M 152.815625 88.74 \r\nL 152.815625 88.421484 \r\nL 152.815625 7.518516 \r\nL 152.815625 7.2 \r\nL 156.892625 7.2 \r\nL 156.892625 7.518516 \r\nL 156.892625 88.421484 \r\nL 156.892625 88.74 \r\nz\r\n\" style=\"fill:#ffffff;stroke:#ffffff;stroke-linejoin:miter;stroke-width:0.01;\"/>\r\n   </g>\r\n   <image height=\"82\" id=\"image6dfb3d125d\" transform=\"scale(1 -1)translate(0 -82)\" width=\"4\" x=\"153\" xlink:href=\"data:image/png;base64,\r\niVBORw0KGgoAAAANSUhEUgAAAAQAAABSCAYAAABzJnWUAAAAnklEQVR4nJ2SOw5CMQwEjfTuf1QqChT/aMnskwxJl9F414ry6Nez7etclmE7qCLIAfRonABZLBy1o2FJEBzxxdD/M7R2NFh7EzpmHBhSm3z1H/YAaDEk1OfVBfCPCXCAEmPJppJRg1Fs0QyCYGhmw4jRIMjkSMgIjBBDapvGfrcrTYwDYINxsweAM8NpvAtgCeDIYqiMSIsYUrs/sdkHBN7KuQYqCJAAAAAASUVORK5CYII=\" y=\"-6\"/>\r\n   <g id=\"matplotlib.axis_3\"/>\r\n   <g id=\"matplotlib.axis_4\">\r\n    <g id=\"ytick_3\">\r\n     <g id=\"line2d_5\">\r\n      <defs>\r\n   
    <path d=\"M 0 0 \r\nL 3.5 0 \r\n\" id=\"m62556b7519\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n      </defs>\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#m62556b7519\" y=\"88.74\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_7\">\r\n      <!-- 0.0 -->\r\n      <g transform=\"translate(163.892625 92.539219)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 10.6875 12.40625 \r\nL 21 12.40625 \r\nL 21 0 \r\nL 10.6875 0 \r\nz\r\n\" id=\"DejaVuSans-46\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\r\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_4\">\r\n     <g id=\"line2d_6\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#m62556b7519\" y=\"56.124\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_8\">\r\n      <!-- 0.2 -->\r\n      <g transform=\"translate(163.892625 59.923219)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 19.1875 8.296875 \r\nL 53.609375 8.296875 \r\nL 53.609375 0 \r\nL 7.328125 0 \r\nL 7.328125 8.296875 \r\nQ 12.9375 14.109375 22.625 23.890625 \r\nQ 32.328125 33.6875 34.8125 36.53125 \r\nQ 39.546875 41.84375 41.421875 45.53125 \r\nQ 43.3125 49.21875 43.3125 52.78125 \r\nQ 43.3125 58.59375 39.234375 62.25 \r\nQ 35.15625 65.921875 28.609375 65.921875 \r\nQ 23.96875 65.921875 18.8125 64.3125 \r\nQ 13.671875 62.703125 7.8125 59.421875 \r\nL 7.8125 69.390625 \r\nQ 13.765625 71.78125 18.9375 73 \r\nQ 24.125 74.21875 28.421875 74.21875 \r\nQ 39.75 74.21875 46.484375 68.546875 \r\nQ 53.21875 62.890625 53.21875 53.421875 \r\nQ 53.21875 48.921875 51.53125 44.890625 \r\nQ 49.859375 40.875 45.40625 35.40625 \r\nQ 44.1875 33.984375 37.640625 27.21875 \r\nQ 31.109375 20.453125 19.1875 8.296875 \r\nz\r\n\" id=\"DejaVuSans-50\"/>\r\n       </defs>\r\n       <use 
xlink:href=\"#DejaVuSans-48\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\r\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n    <g id=\"ytick_5\">\r\n     <g id=\"line2d_7\">\r\n      <g>\r\n       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"156.892625\" xlink:href=\"#m62556b7519\" y=\"23.508\"/>\r\n      </g>\r\n     </g>\r\n     <g id=\"text_9\">\r\n      <!-- 0.4 -->\r\n      <g transform=\"translate(163.892625 27.307219)scale(0.1 -0.1)\">\r\n       <defs>\r\n        <path d=\"M 37.796875 64.3125 \r\nL 12.890625 25.390625 \r\nL 37.796875 25.390625 \r\nz\r\nM 35.203125 72.90625 \r\nL 47.609375 72.90625 \r\nL 47.609375 25.390625 \r\nL 58.015625 25.390625 \r\nL 58.015625 17.1875 \r\nL 47.609375 17.1875 \r\nL 47.609375 0 \r\nL 37.796875 0 \r\nL 37.796875 17.1875 \r\nL 4.890625 17.1875 \r\nL 4.890625 26.703125 \r\nz\r\n\" id=\"DejaVuSans-52\"/>\r\n       </defs>\r\n       <use xlink:href=\"#DejaVuSans-48\"/>\r\n       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\r\n       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-52\"/>\r\n      </g>\r\n     </g>\r\n    </g>\r\n   </g>\r\n   <g id=\"patch_8\">\r\n    <path d=\"M 152.815625 88.74 \r\nL 152.815625 88.421484 \r\nL 152.815625 7.518516 \r\nL 152.815625 7.2 \r\nL 156.892625 7.2 \r\nL 156.892625 7.518516 \r\nL 156.892625 88.421484 \r\nL 156.892625 88.74 \r\nz\r\n\" style=\"fill:none;stroke:#000000;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n   </g>\r\n  </g>\r\n </g>\r\n <defs>\r\n  <clipPath id=\"p2a4aafd979\">\r\n   <rect height=\"22.32\" width=\"111.6\" x=\"34.240625\" y=\"36.81\"/>\r\n  </clipPath>\r\n  <clipPath id=\"p7469d24906\">\r\n   <rect height=\"81.54\" width=\"4.077\" x=\"152.815625\" y=\"7.2\"/>\r\n  </clipPath>\r\n </defs>\r\n</svg>\r\n",
      "text/plain": [
       "<Figure size 180x180 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 可以将注意力汇聚的输出计算作为值的加权平均，选择不同的注意力评分函数会带来不同的注意力操作。\n",
    "* 当查询和键是不同长度的矢量时，可以使用可加性注意力评分函数。当它们的长度相同时，使用缩放的 \"点——积\" 注意力评分函数的计算效率更高。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_gpu",
   "language": "python",
   "name": "pytorch_gpu"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
